package com.fasterxml.jackson.module.afterburner.util;

import java.lang.reflect.Method;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.concurrent.ConcurrentHashMap;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * Class loader used to define (and resolve) generated classes.
 *<p>
 * When configured to do so (see {@link #_cfgUseParentLoader}), it first tries to define
 * a generated class directly on its parent class loader, by reflectively invoking the
 * protected {@code ClassLoader} methods; this may be done to gain access to
 * protected/package-access properties. If that is not configured, or does not succeed,
 * the class is defined by this loader instead.
 */
public class MyClassLoader extends ClassLoader
{
    private final static Charset UTF8 = StandardCharsets.UTF_8;

    // Maps (parent classloader identity, class name) to the lock object used while defining
    // that class on that parent.
    // N.B. this must be static because multiple instances of MyClassLoader must all use the same lock
    // when loading classes directly on the same parent.
    // NOTE(review): entries are never removed, so this map grows by one entry per distinct
    // (parent, class name) pair for as long as this class stays loaded.
    private final static ConcurrentHashMap<String, Object> parentParallelLockMap = new ConcurrentHashMap<>();

    /**
     * Flag that determines if we should first try to load new class
     * using parent class loader or not; this may be done to try to
     * force access to protected/package-access properties.
     */
    protected final boolean _cfgUseParentLoader;

    /**
     * @param parent Parent class loader
     * @param tryToUseParent Whether to first try defining generated classes directly on
     *   {@code parent} (see {@link #_cfgUseParentLoader})
     */
    public MyClassLoader(ClassLoader parent, boolean tryToUseParent)
    {
        super(parent);
        _cfgUseParentLoader = tryToUseParent;
    }

    /**
     * Helper method called to check whether it is acceptable to create a new
     * class in package that given class is part of.
     * This is used to prevent certain class of failures, related to access
     * limitations: for example, we can not add classes in sealed packages,
     * or core Java packages (java.*).
     *
     * @param cls Class whose package is checked
     *
     * @return {@code false} if the package of {@code cls} is sealed, or is under
     *   {@code java.} or {@code javax.security.}; {@code true} otherwise (including
     *   when no {@link Package} is available for {@code cls})
     *
     * @since 2.2.1
     */
    public static boolean canAddClassInPackageOf(Class<?> cls)
    {
        final Package beanPackage = cls.getPackage();
        if (beanPackage != null) {
            if (beanPackage.isSealed()) {
                return false;
            }
            String pname = beanPackage.getName();
            /* 14-Aug-2014, tatu: java.* we do not want to touch, but
             *    javax is a bit trickier: for now we only exclude javax.security.*
             */
            if (pname.startsWith("java.")
                    || pname.startsWith("javax.security.")) {
                return false;
            }
        }
        return true;
    }

    /**
     * Defines and resolves the generated class named by {@code className}, or returns the
     * already-loaded class of that name if there is one.
     *<p>
     * If this loader was created with {@code tryToUseParent == true} and has a parent, the
     * class is first defined on the parent class loader; if that does not produce a class,
     * it is defined by this loader instead.
     *<p>
     * Note: {@code byteCode} is modified in place. Every occurrence of the template class
     * name is overwritten with the actual class name before the class is defined.
     *
     * @param className Name of the class to define, along with the (same-length) template
     *   name that the generated bytecode was produced with
     * @param byteCode Generated bytecode of the class; modified in place
     *
     * @return Loaded and resolved class
     *
     * @throws IllegalArgumentException if this loader fails to define the class (wraps
     *   the root cause of the {@link LinkageError}), or if the template and actual names
     *   differ in encoded length
     */
    public Class<?> loadAndResolve(ClassName className, byte[] byteCode)
            throws IllegalArgumentException
    {
        // first, try to loadAndResolve via the parent classloader, if configured to do so
        Class<?> classFromParent = loadAndResolveUsingParentClassloader(className, byteCode);
        if (classFromParent != null) {
            return classFromParent;
        }

        // fall back to loading and resolving ourselves
        synchronized (getClassLoadingLock(className.getDottedName())) {
            // First: check to see if we have loaded it ourselves
            Class<?> existingClass = findLoadedClass(className.getDottedName());
            if (existingClass != null) {
                return existingClass;
            }

            // Important: bytecode is generated with a template name (since bytecode itself
            // is used for checksum calculation) -- must be replaced now, however
            replaceName(byteCode, className.getSlashedTemplate(), className.getSlashedName());

            // Second: define a new class instance using the bytecode
            Class<?> newClass;
            try {
                newClass = defineClass(className.getDottedName(), byteCode, 0, byteCode.length);
            } catch (LinkageError e) {
                // Report the innermost cause, which usually carries the most specific message
                Throwable t = e;
                while (t.getCause() != null) {
                    t = t.getCause();
                }
                throw new IllegalArgumentException("Failed to load class '" + className + "': " + t.getMessage(), t);
            }
            // important: must also resolve the newly-created class.
            resolveClass(newClass);
            return newClass;
        }
    }

    /**
     * Attempt to load (and resolve) the class using the parent class loader (if it is configured and present).
     * This method will return {@code null} if the parent classloader is not configured or cannot be retrieved,
     * or if the class could not be defined on it (such failures are logged at {@code FINE} level).
     *<p>
     * Note: when the parent loader is used, {@code byteCode} is modified in place (template name
     * replaced with the actual name).
     *
     * @param className Name (and template name) of the class to define
     * @param byteCode  the generated bytecode for the class to load
     * @return          the loaded class, or {@code null} if the class could not be loaded on the parent classloader.
     */
    private Class<?> loadAndResolveUsingParentClassloader(ClassName className, byte[] byteCode)
    {
        if (!_cfgUseParentLoader) {
            return null;
        }
        final ClassLoader parentClassLoader = getParent();
        if (parentClassLoader == null) {
            return null;
        }
        // N.B. The parent-class-loading locks are shared between all instances of MyClassLoader.
        // We can be confident that no attempt will be made to re-acquire *any* parent-class-loading lock instance
        // inside the synchronized region (eliminating the risk of deadlock), even if the parent class loader is also
        // an instance of MyClassLoader, because:
        //      a) this method is the only place that attempts to acquire a parent class loading lock,
        //      b) the only non-private method which calls this method and thus acquires this lock is
        //          MyClassLoader#loadAndResolve,
        //      c) nothing in the synchronized region can have the effect of calling #loadAndResolve on this
        //          or any other instance of MyClassLoader.
        synchronized (getParentClassLoadingLock(parentClassLoader, className.getDottedName())) {
            // First: check to see if the parent classloader has loaded it already
            Class<?> impl = findLoadedClassOnParent(parentClassLoader, className.getDottedName());
            if (impl != null) {
                return impl;
            }

            // Important: bytecode is generated with a template name (since bytecode itself
            // is used for checksum calculation) -- must be replaced now, however
            replaceName(byteCode, className.getSlashedTemplate(), className.getSlashedName());

            // Second: define a new class instance on the parent classloader using the bytecode
            impl = defineClassOnParent(parentClassLoader, className.getDottedName(), byteCode, 0, byteCode.length);
            // important: must also resolve the newly-created class -- but only if defining succeeded;
            // a null here means the failure was already logged and the caller will fall back.
            if (impl != null) {
                resolveClassOnParent(parentClassLoader, impl);
            }
            return impl;
        }
    }

    /**
     * Get the class loading lock for the parent class loader for loading the named class.
     *
     * This is effectively the same implementation as ClassLoader#getClassLoadingLock, but using
     * our static parentParallelLockMap and keying off of the parent ClassLoader instance as well as
     * the class name to load.
     *
     * @param parentClassLoader     The parent ClassLoader
     * @param className             The name of the to-be-loaded class
     * @return                      The lock object shared by all callers using the same key
     */
    private Object getParentClassLoadingLock(ClassLoader parentClassLoader, String className) {
        // N.B. using the class name and identity hash code to represent the parent class loader in the key
        // in case that ClassLoader instance (which could be anything) implements #hashCode or #toString poorly.
        // (Class#getName is used rather than #getCanonicalName, which is null for anonymous/local classes.)
        // In the event of a collision here (same key, different parent class loader), we will end up using the
        // same lock only to synchronize loads of the same class on two different class loaders,
        // which shouldn't ever deadlock (see proof in #loadAndResolveUsingParentClassloader);
        // worst case is unnecessary contention for the lock.
        String key = parentClassLoader.getClass().getName()
                + ":" + System.identityHashCode(parentClassLoader)
                + ":" + className;
        // Fast path: avoid allocating a candidate lock when one is already registered
        Object lock = parentParallelLockMap.get(key);
        if (lock == null) {
            Object newLock = new Object();
            lock = parentParallelLockMap.putIfAbsent(key, newLock);
            if (lock == null) {
                lock = newLock;
            }
        }
        return lock;
    }

    /**
     * Reflectively calls the protected {@code ClassLoader#findLoadedClass} on the parent.
     *
     * @return the class already loaded by {@code parentClassLoader} under {@code className},
     *   or {@code null} if there is none or the reflective call failed (logged at {@code FINE})
     */
    private Class<?> findLoadedClassOnParent(ClassLoader parentClassLoader, String className) {
        try {
            Method method = ClassLoader.class.getDeclaredMethod("findLoadedClass", String.class);
            method.setAccessible(true);
            return (Class<?>) method.invoke(parentClassLoader, className);
        } catch (Exception e) {
            String msg = String.format("Exception trying 'findLoadedClass(%s)' on parent ClassLoader '%s'",
                    className, parentClassLoader);
            Logger.getLogger(MyClassLoader.class.getName()).log(Level.FINE, msg, e);
            return null;
        }
    }

    /**
     * Reflectively calls the protected {@code ClassLoader#defineClass(String, byte[], int, int)}
     * on the parent.
     *
     * @return the newly defined class, or {@code null} if the reflective call failed
     *   (logged at {@code FINE})
     */
    // visible for testing
    Class<?> defineClassOnParent(ClassLoader parentClassLoader,
                                 String className,
                                 byte[] byteCode,
                                 int offset,
                                 int length) {
        try {
            Method method = ClassLoader.class.getDeclaredMethod("defineClass",
                    String.class, byte[].class, int.class, int.class);
            method.setAccessible(true);
            return (Class<?>) method.invoke(parentClassLoader,
                    className, byteCode, offset, length);
        } catch (Exception e) {
            String msg = String.format("Exception trying 'defineClass(%s, <bytecode>)' on parent ClassLoader '%s'",
                    className, parentClassLoader);
            Logger.getLogger(MyClassLoader.class.getName()).log(Level.FINE, msg, e);
            return null;
        }
    }

    /**
     * Reflectively calls the protected {@code ClassLoader#resolveClass} on the parent.
     * Failures are logged at {@code FINE} level and otherwise ignored.
     */
    private void resolveClassOnParent(ClassLoader parentClassLoader, Class<?> clazz) {
        try {
            Method method = ClassLoader.class.getDeclaredMethod("resolveClass", Class.class);
            method.setAccessible(true);
            method.invoke(parentClassLoader, clazz);
        } catch (Exception e) {
            String msg = String.format("Exception trying 'resolveClass(%s)' on parent ClassLoader '%s'",
                    clazz, parentClassLoader);
            Logger.getLogger(MyClassLoader.class.getName()).log(Level.FINE, msg, e);
        }
    }

    /**
     * Replaces, in place, every occurrence of the UTF-8 encoding of {@code from} in
     * {@code byteCode} with the UTF-8 encoding of {@code to}. Occurrences are found by
     * scanning left to right and do not overlap: scanning resumes right after each
     * replaced region.
     *
     * @param byteCode Bytes to modify in place
     * @param from String whose encoded bytes are searched for
     * @param to Replacement; must encode to the same number of bytes as {@code from}
     *
     * @return Number of replacements made (0 if {@code from} is empty)
     *
     * @throws IllegalArgumentException if {@code from} and {@code to} encode to different lengths
     */
    public static int replaceName(byte[] byteCode,
            String from, String to)
    {
        final byte[] fromB = from.getBytes(UTF8);
        final byte[] toB = to.getBytes(UTF8);

        final int matchLength = fromB.length;

        // sanity check
        if (matchLength != toB.length) {
            throw new IllegalArgumentException("From String '"+from
                    +"' has different length than To String '"+to+"'");
        }
        // Nothing to search for (previously failed with ArrayIndexOutOfBoundsException)
        if (matchLength == 0) {
            return 0;
        }

        int count = 0;
        final int lastStart = byteCode.length - matchLength;
        int pos = 0;
        // naive O(n*m) scan; byte code is small enough for this to be fine
        while (pos <= lastStart) {
            if (regionMatches(byteCode, pos, fromB)) {
                System.arraycopy(toB, 0, byteCode, pos, matchLength);
                ++count;
                pos += matchLength;
            } else {
                ++pos;
            }
        }
        return count;
    }

    // Whether `needle` occurs in `haystack` starting at `offset`; caller guarantees bounds.
    private static boolean regionMatches(byte[] haystack, int offset, byte[] needle)
    {
        for (int j = 0; j < needle.length; ++j) {
            if (haystack[offset + j] != needle[j]) {
                return false;
            }
        }
        return true;
    }
}
